
Optimise cache and precision for DL training#186

Open
prockenschaub wants to merge 7 commits into rvandewater:development from prockenschaub:speedups

Conversation

@prockenschaub
Collaborator

@prockenschaub prockenschaub commented Apr 24, 2026

Addresses #184 and #185 to improve training on GPU cluster.

Also fixes everything to float32, both for memory efficiency and for compatibility with MPS.

Summary by CodeRabbit

Release Notes

  • Refactor
    • Improved data loading performance through optimized caching and batch processing mechanisms.
    • Standardized numerical precision handling to ensure consistent tensor operations across datasets.
    • Enhanced memory efficiency by enabling pinned memory for data loading operations.

@coderabbitai

coderabbitai Bot commented Apr 24, 2026

Warning

Rate limit exceeded

@prockenschaub has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 0 minutes and 20 seconds before requesting another review.

Your organization is not enrolled in usage-based pricing. Contact your admin to enable usage-based pricing to continue reviews beyond the rate limit, or try again in 0 minutes and 20 seconds.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8ab1aab8-404e-499c-b7db-0dab07355fcc

📥 Commits

Reviewing files that changed from the base of the PR and between 816964f and 0ffe22d.

📒 Files selected for processing (2)
  • icu_benchmarks/data/loader.py
  • icu_benchmarks/models/train.py

Walkthrough

This PR refactors PredictionPolarsDataset to use a precomputed one-pass partitioning strategy with a new _build_item() helper method and cached item retrieval, while standardizing data type conversions to float32 across both Polars and Pandas dataset classes. Training configuration is updated to broaden precision parameter types and enable memory pinning in DataLoaders with revised matmul precision logic.

Changes

Cohort / File(s) Summary
Dataset Loader Refactoring
icu_benchmarks/data/loader.py
PredictionPolarsDataset now uses precomputed partitions (_stay_order, _feat_arrays, _label_arrays) keyed by GROUP with new _build_item(idx) method for per-stay samples instead of filtering dataframes in __getitem__. Public __getitem__ wrapped as cached-data dispatcher. Both PredictionPolarsDataset and PredictionPandasDataset updated to standardize labels/features/indicators to np.float32 and Tensor float32 types, removing mps-conditional branching. Return types of to_tensor() narrowed from Union[Tensor, np.ndarray] to Tensor.
Training Configuration
icu_benchmarks/models/train.py
train_common precision parameter type broadened from Literal["16-true"] to generic str to accept wider range of precision values. DataLoader creation updated to enable pin_memory=True for training and validation. Matmul precision configuration logic refactored to use membership check for 16/bf16/mixed variants instead of previous conditional, setting float32 matmul precision to "high" for relevant cases.
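The precompute-then-cache pattern the walkthrough describes can be sketched as follows. Attribute and method names (`_stay_order`, `_feat_arrays`, `_label_arrays`, `_build_item`) follow the summary above, but the grouping and caching logic here is illustrative, not the actual loader code:

```python
from functools import lru_cache
import numpy as np

class CachedDataset:
    """Sketch: partition rows by stay id once, then serve cached per-stay items."""

    def __init__(self, features: np.ndarray, labels: np.ndarray, groups: np.ndarray):
        # One pass over the data: bucket feature/label rows by group (stay id)
        # up front, instead of filtering the full dataframe on every __getitem__.
        self._stay_order = list(dict.fromkeys(groups.tolist()))  # stable unique order
        self._feat_arrays = {g: features[groups == g].astype(np.float32) for g in self._stay_order}
        self._label_arrays = {g: labels[groups == g].astype(np.float32) for g in self._stay_order}

    def __len__(self) -> int:
        return len(self._stay_order)

    @lru_cache(maxsize=None)
    def _build_item(self, idx: int):
        stay_id = self._stay_order[idx]
        return self._feat_arrays[stay_id], self._label_arrays[stay_id]

    def __getitem__(self, idx: int):
        # Dispatch through the cache so repeated epochs reuse built items.
        return self._build_item(idx)
```

The key design choice is paying the partitioning cost once at construction rather than per sample, which matters when `__getitem__` is called once per stay per epoch.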

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes
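The revised matmul-precision logic in `train.py` (a membership check over the half/mixed precision variants) can be sketched like this; the helper name and the exact set of accepted strings are assumptions, not the actual code:

```python
# Precision strings that warrant relaxed float32 matmul precision. "high" is
# what torch.set_float32_matmul_precision("high") configures: TF32-style fast
# float32 matmuls on supported GPUs. "highest" keeps full float32 accuracy.
FAST_MATMUL_PRECISIONS = {"16", "16-true", "16-mixed", "bf16", "bf16-true", "bf16-mixed"}

def matmul_precision_for(precision: str) -> str:
    return "high" if precision in FAST_MATMUL_PRECISIONS else "highest"
```

A set membership check like this replaces a chain of conditionals and makes adding new precision variants a one-line change.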

Poem

🐰 The dataset hops with precomputed grace,
No more filtering in every place!
Float32 tensors, cached with care,
Precision broadened—config's fair!
From Polars' burrows to training ground,
Optimizations abound all around! 🌾

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title directly aligns with the main changes: dataset caching optimization via the new _build_item method and precision standardization to float32 across both dataset classes.
  • Linked Issues check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
icu_benchmarks/data/loader.py (2)

115-134: ⚠️ Potential issue | 🔴 Critical

Critical: in-place mutation of cached _label_arrays[stay_id] corrupts pad_mask on subsequent calls.

On Line 116, labels is a live reference to the array stored in self._label_arrays[stay_id]. When len(labels) > 1 and length_diff <= 0 (i.e., a stay whose length equals maxlen), neither the len(labels) == 1 nor the length_diff > 0 branch allocates a new array, so labels[not_labeled] = -1 on Line 133 mutates the cached entry in place.

After the first call, the cached labels have -1 where NaNs used to be. On any subsequent rebuild (e.g., when ram_cache=False, or if ram_cache(True) is ever called again), np.argwhere(np.isnan(labels)) returns an empty array, so the pad_mask[not_labeled] = 0 line on Line 134 is skipped and pad_mask is returned as all ones at positions that should have been masked out — silently feeding unlabeled timesteps into the loss.

window is never mutated, so only labels needs a defensive copy.

🐛 Proposed fix
         window = self._feat_arrays[stay_id]
-        labels = self._label_arrays[stay_id]
+        # Copy to avoid mutating the cached entry below (labels[not_labeled] = -1).
+        labels = self._label_arrays[stay_id].copy()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@icu_benchmarks/data/loader.py` around lines 115 - 134, The issue is in-place
mutation of the cached labels array (self._label_arrays[stay_id]) causing
persistent -1s that break pad_mask later; fix by making a defensive copy of the
labels before any mutation (e.g., assign labels =
self._label_arrays[stay_id].copy() or np.array(self._label_arrays[stay_id],
copy=True) immediately after retrieving it), and ensure any branches that build
new label arrays (the len(labels) == 1 concatenation and the length_diff > 0
padding branch) operate on and return this copied/allocated array so subsequent
writes like labels[not_labeled] = -1 do not mutate the cached self._label_arrays
entry (refer to variables stay_id, labels, self._label_arrays, not_labeled,
pad_mask).
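The aliasing failure mode described above is easy to reproduce in isolation with NumPy: a write through a reference retrieved from a cache mutates the cached entry itself, while a defensive `.copy()` leaves it intact. The cache and values here are illustrative, not the loader's actual data:

```python
import numpy as np

cache = {"stay-1": np.array([0.5, np.nan, 1.0])}

# Without a copy: the write goes straight through to the cached array,
# so the NaN marker that drives pad_mask is gone on the next lookup.
labels = cache["stay-1"]
labels[np.isnan(labels)] = -1.0
assert np.isnan(cache["stay-1"]).sum() == 0  # cache silently corrupted

# With a defensive copy: the cached entry keeps its NaN for the next rebuild.
cache["stay-1"] = np.array([0.5, np.nan, 1.0])  # reset for the demo
labels = cache["stay-1"].copy()
labels[np.isnan(labels)] = -1.0
assert np.isnan(cache["stay-1"]).sum() == 1
```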

174-195: ⚠️ Potential issue | 🔴 Critical

Critical: float32 cast of row_indicators corrupts stay-id precision in exported predictions.

row_indicators carries the GROUP (stay_id) and, for temporal data, SEQUENCE (time in hours). These are actively passed to _save_model_outputs() and exported to pred_indicators.csv for downstream analysis. float32 can only represent integers exactly up to 2²⁴ ≈ 16.7 million; hospital-derived stay ids (e.g., MIMIC-IV) routinely exceed this. Unconditionally casting to float32 on line 194 silently collapses distinct ids, corrupting the CSV export and breaking any attempt to match predictions back to original stays.

This is a regression: commit 52117f4 introduced conditional mps-based casting (if self.mps: ...else: native dtype), but the self.mps flag (line 53) is set but never used in to_tensor(). The condition was removed and replaced with unconditional float32 casting.

Only data and labels require float32 for tensor operations; row_indicators is metadata and should retain its native dtype.

Fix for to_tensor
     def to_tensor(self) -> tuple[Tensor, Tensor, Tensor]:
         data, labels, row_indicators = self.get_data_and_labels()
-        # Always use float32 for memory efficiency and MPS compatibility
         return (
             from_numpy(data),
             from_numpy(labels),
-            from_numpy(row_indicators.astype(np.float32)),
+            from_numpy(row_indicators),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@icu_benchmarks/data/loader.py` around lines 174 - 195, The to_tensor method
currently forces row_indicators to float32 which corrupts stay_id precision;
update to_tensor (and/or get_data_and_labels usage) so only data and labels are
cast to float32 before creating tensors and row_indicators is passed through
with its native integer dtype (do not call .astype(np.float32) on
row_indicators). Use the existing self.mps flag if you need MPS-specific casting
for data/labels, but ensure row_indicators stays as the original numpy dtype
when calling from_numpy(row_indicators) so exported pred_indicators.csv keeps
exact stay_id/sequence values.
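The 2²⁴ boundary the comment cites is easy to verify directly with NumPy: consecutive integer ids above it collapse to the same float32 value, while float64 (or the native integer dtype) keeps them distinct:

```python
import numpy as np

# 2**24 = 16,777,216 is the edge of float32's exact integer range.
stay_ids = np.array([16_777_216, 16_777_217], dtype=np.int64)

as_f32 = stay_ids.astype(np.float32)
assert as_f32[0] == as_f32[1]          # distinct ids collapse to the same value

as_f64 = stay_ids.astype(np.float64)   # float64 preserves them exactly here
assert as_f64[0] != as_f64[1]
```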
🧹 Nitpick comments (1)
icu_benchmarks/models/train.py (1)

114-131: Consider gating pin_memory on accelerator availability.

When cpu=True the trainer runs on CPU and pinning is wasted work (and can emit warnings on some backends). A trivial gate avoids this:

-        persistent_workers=persistent_workers,
-        pin_memory=True,
+        persistent_workers=persistent_workers,
+        pin_memory=not cpu,

…applied to both the train and val loaders (the test loader at Line 199 already runs inside the same cpu scope and could be tightened similarly).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@icu_benchmarks/models/train.py` around lines 114 - 131, The DataLoader
instances (train_loader and val_loader) set pin_memory=True unconditionally
which wastes work or triggers warnings when running on CPU; change both
DataLoader constructions to set pin_memory = (not cpu) or equivalent so pinning
is enabled only when an accelerator/GPU is used (reference DataLoader,
train_loader, val_loader and the cpu/trainer flag), and apply the same gating
pattern used for the test loader scope.
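The gating the nitpick proposes reduces to a small predicate. This standalone helper is hypothetical (not part of train.py) and just makes the intended condition explicit:

```python
def pin_memory_flag(cpu: bool, accelerator_available: bool) -> bool:
    # Pinning page-locks host memory to speed host-to-device copies; it only
    # pays off when an accelerator will actually consume the batches, so it
    # should be off both when cpu=True and when no accelerator is present.
    return (not cpu) and accelerator_available
```

In the actual DataLoader calls this would replace the unconditional `pin_memory=True`.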

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5bda1778-f36e-40e7-ac75-7b6de4aded30

📥 Commits

Reviewing files that changed from the base of the PR and between ef999f7 and 816964f.

📒 Files selected for processing (2)
  • icu_benchmarks/data/loader.py
  • icu_benchmarks/models/train.py

